{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a8ee28d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from multiprocess import Pool\n",
    "import itertools\n",
    "import numpy as np\n",
    "\n",
    "def chunks(l, n):\n",
    "    for i in range(0, len(l), n):\n",
    "        yield (l[i: i + n], i // n)\n",
    "\n",
    "def multiprocessing(strings, function, cores=6, returned=True):\n",
    "    df_split = chunks(strings, len(strings) // cores)\n",
    "    pool = Pool(cores)\n",
    "    pooled = pool.map(function, df_split)\n",
    "    pool.close()\n",
    "    pool.join()\n",
    "\n",
    "    if returned:\n",
    "        return list(itertools.chain(*pooled))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2fb31c7c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0a09c5ae1a894beb88ff36399b924436",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "README.md:   0%|          | 0.00/459 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/lib/python3/dist-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.4\n",
      "  warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion}\"\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "197936ef682840f58b5f07f8e6f2359c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "train-00000-of-00006.parquet:   0%|          | 0.00/442M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1f9eed511ad14fe3b69c6d8784dcc110",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "train-00001-of-00006.parquet:   0%|          | 0.00/427M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5fed5752fbb64b629f41d73300d71933",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "train-00002-of-00006.parquet:   0%|          | 0.00/431M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8b981ed5aa984d70a1997c9dc4c52a5c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "train-00003-of-00006.parquet:   0%|          | 0.00/444M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "eab60bf53732421ba15b322accaa66f0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "train-00004-of-00006.parquet:   0%|          | 0.00/436M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dd9a3f5cb82d469c82cbcb4d3e2b60ef",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "train-00005-of-00006.parquet:   0%|          | 0.00/431M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c20507c53447412387a9ea03815f681e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating train split:   0%|          | 0/1584 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DatasetDict({\n",
      "    train: Dataset({\n",
      "        features: ['audio', 'speakers', 'timestamps_start', 'timestamps_end'],\n",
      "        num_rows: 1584\n",
      "    })\n",
      "})\n"
     ]
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "# Load the dataset by its Hub identifier; `ds` is used further down to count the rows of the 'train' split.\n",
    "ds = load_dataset(\"diarizers-community/synthetic-speaker-diarization-dataset\")\n",
    "\n",
    "# Show the available splits, their columns and row counts.\n",
    "print(ds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "fb88cb25",
   "metadata": {},
   "outputs": [],
   "source": [
    "import string\n",
    "import soundfile as sf\n",
    "import numpy as np\n",
    "from collections import defaultdict\n",
    "\n",
    "def convert_rttm(chunk, filename = 'audio'):\n",
    "    rttm = []\n",
    "    for start, end, speaker in chunk:\n",
    "        duration = end - start\n",
    "        rttm.append(f\"SPEAKER {filename} 1 {start:.4f} {duration:.4f} <NA> <NA> <NA> <NA> {speaker}\")\n",
    "    return '\\n'.join(rttm)\n",
    "\n",
    "def convert_textgrid(segments):\n",
    "    tiers = defaultdict(list)\n",
    "    for start, end, speaker in segments:\n",
    "        tiers[speaker].append((start, end))\n",
    "\n",
    "    min_time = min(start for start, _, _ in segments)\n",
    "    max_time = max(end for _, end, _ in segments)\n",
    "\n",
    "    textgrid = []\n",
    "    textgrid.append(\"File type = \\\"ooTextFile\\\"\")\n",
    "    textgrid.append(\"Object class = \\\"TextGrid\\\"\")\n",
    "    textgrid.append(\"\")\n",
    "    textgrid.append(f\"xmin = {min_time:.2f}\")\n",
    "    textgrid.append(f\"xmax = {max_time:.2f}\")\n",
    "    textgrid.append(\"tiers? <exists>\")\n",
    "    textgrid.append(f\"size = {len(tiers)}\")\n",
    "    textgrid.append(\"item []:\")\n",
    "\n",
    "    for i, (speaker, intervals) in enumerate(tiers.items(), start=1):\n",
    "        textgrid.append(f\"    item [{i}]:\")\n",
    "        textgrid.append(\"        class = \\\"IntervalTier\\\"\")\n",
    "        textgrid.append(f\"        name = \\\"{speaker}\\\"\")\n",
    "        textgrid.append(f\"        xmin = {min_time:.2f}\")\n",
    "        textgrid.append(f\"        xmax = {max_time:.2f}\")\n",
    "        textgrid.append(f\"        intervals: size = {len(intervals)}\")\n",
    "\n",
    "        for j, (start, end) in enumerate(intervals, start=1):\n",
    "            textgrid.append(f\"        intervals [{j}]:\")\n",
    "            textgrid.append(f\"            xmin = {start:.2f}\")\n",
    "            textgrid.append(f\"            xmax = {end:.2f}\")\n",
    "            textgrid.append(f\"            text = \\\"{speaker}\\\"\")\n",
    "            \n",
    "    return '\\n'.join(textgrid)\n",
    "\n",
    "# Grid of 1501 candidate times at multiples of 0.02, from 0 up to 1500 * 0.02;\n",
    "# `loop` snaps clip-relative segment boundaries to the closest entry of this list.\n",
    "timestamps = [i * 0.02 for i in range(1500 + 1)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2755f299",
   "metadata": {},
   "outputs": [],
   "source": [
    "# !rm -rf synthetic-speaker-diarization-dataset\n",
    "# -p keeps this cell re-runnable: plain mkdir reports an error once the directory exists.\n",
    "!mkdir -p synthetic-speaker-diarization-dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "3987a18d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "import os\n",
    "\n",
    "def loop(indices):\n",
    "    \"\"\"Cut dataset rows into short audio clips and build diarization instruction rows.\n",
    "\n",
    "    `indices` is an `(items, chunk_no)` pair as yielded by `chunks`; every item\n",
    "    is a `(row_index, split_name)` tuple. For each row the segments are sorted\n",
    "    by start time and grouped greedily: a segment joins the running group while\n",
    "    it ends less than `max_len` seconds after the group's first start. Every\n",
    "    usable group is written to disk as one clip (skipped when the file already\n",
    "    exists) and described in three formats: whisper-style timestamps, RTTM and\n",
    "    TextGrid. Speakers are renamed `speaker A`, `speaker B`, ... in order of\n",
    "    first appearance inside the group, and times are made relative to the clip\n",
    "    and snapped to the module-level `timestamps` grid.\n",
    "\n",
    "    Returns a list of dicts with the keys `question`, `answer` and\n",
    "    `audio_filename`, three per clip.\n",
    "    \"\"\"\n",
    "    indices, _ = indices\n",
    "    # Each call loads the dataset itself instead of relying on a module-level `ds`.\n",
    "    ds = load_dataset(\"diarizers-community/synthetic-speaker-diarization-dataset\")\n",
    "    max_len = 30\n",
    "    # NOTE(review): slicing and writing assume 16 kHz audio - confirm against the dataset.\n",
    "    sample_rate = 16000\n",
    "    data = []\n",
    "    for k, key in tqdm(indices):\n",
    "        row = ds[key][k]\n",
    "        if len(row['timestamps_start']) == 0:\n",
    "            # No segments to cut; indexing the first start below would raise.\n",
    "            continue\n",
    "\n",
    "        audio = row['audio']['array']\n",
    "        order = np.argsort(row['timestamps_start'])\n",
    "        timestamps_start = [row['timestamps_start'][i] for i in order]\n",
    "        timestamps_end = [row['timestamps_end'][i] for i in order]\n",
    "        speakers = [row['speakers'][i] for i in order]\n",
    "\n",
    "        groups, temp = [], []\n",
    "        start = timestamps_start[0]\n",
    "        for seg_start, seg_end, speaker in zip(timestamps_start, timestamps_end, speakers):\n",
    "            if seg_end - start >= max_len:\n",
    "                # Close the running group and open a new one with this segment.\n",
    "                # If the very first segment is already too long this closes an\n",
    "                # empty group; it is kept so the clip numbering stays unchanged\n",
    "                # and is skipped explicitly below.\n",
    "                groups.append(temp)\n",
    "                temp = [[seg_start, seg_end, speaker]]\n",
    "                start = seg_start\n",
    "            else:\n",
    "                temp.append([seg_start, seg_end, speaker])\n",
    "\n",
    "        if len(temp):\n",
    "            groups.append(temp)\n",
    "\n",
    "        for no, chunk in enumerate(groups):\n",
    "            if not chunk:\n",
    "                # Placeholder left by an over-long first segment.\n",
    "                continue\n",
    "\n",
    "            # Distinct speakers in order of first appearance.\n",
    "            chunk_speakers = []\n",
    "            for segment in chunk:\n",
    "                if segment[-1] not in chunk_speakers:\n",
    "                    chunk_speakers.append(segment[-1])\n",
    "\n",
    "            start_time = chunk[0][0]\n",
    "            end_time = max(c[1] for c in chunk)\n",
    "\n",
    "            # A single long segment or overlapping speech can still exceed the limit.\n",
    "            if round(end_time - start_time, 2) > max_len:\n",
    "                continue\n",
    "\n",
    "            y = audio[int(sample_rate * start_time): int(sample_rate * end_time)]\n",
    "            audio_filename = f'synthetic-speaker-diarization-dataset/{key}-{k}-{no}.mp3'\n",
    "            if not os.path.exists(audio_filename):\n",
    "                sf.write(audio_filename, y, sample_rate)\n",
    "\n",
    "            ts = []\n",
    "            for segment in chunk:\n",
    "                index = chunk_speakers.index(segment[-1])\n",
    "                # Clip-relative times, snapped to the nearest grid entry.\n",
    "                rel_start = segment[0] - start_time\n",
    "                rel_end = segment[1] - start_time\n",
    "                snapped_start = min(timestamps, key=lambda t: abs(t - rel_start))\n",
    "                snapped_end = min(timestamps, key=lambda t: abs(t - rel_end))\n",
    "                speaker_name = f'speaker {string.ascii_uppercase[index]}'\n",
    "                # The segment is rewritten in place so the RTTM and TextGrid\n",
    "                # renderings below use the same relabelled, snapped values.\n",
    "                segment[-1] = speaker_name\n",
    "                segment[0] = snapped_start\n",
    "                segment[1] = snapped_end\n",
    "                ts.append(f\"<|{snapped_start:.2f}|> {speaker_name}<|{snapped_end:.2f}|>\")\n",
    "\n",
    "            answers = (\n",
    "                ('diarize the audio using whisper format', ''.join(ts)),\n",
    "                ('diarize the audio using rttm format', convert_rttm(chunk)),\n",
    "                ('diarize the audio using textgrid format', convert_textgrid(chunk)),\n",
    "            )\n",
    "            for question, answer in answers:\n",
    "                data.append({\n",
    "                    'question': question,\n",
    "                    'answer': answer,\n",
    "                    'audio_filename': audio_filename,\n",
    "                })\n",
    "\n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "2f1d4a50",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:19<00:00,  4.06it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:24<00:00,  3.23it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00,  3.16it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:25<00:00,  3.06it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:26<00:00,  2.99it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:24<00:00,  3.27it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:25<00:00,  3.13it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:24<00:00,  3.20it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:25<00:00,  3.13it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:25<00:00,  3.10it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:25<00:00,  3.08it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:25<00:00,  3.11it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:25<00:00,  3.14it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:25<00:00,  3.09it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:26<00:00,  3.01it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:26<00:00,  3.02it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:26<00:00,  2.99it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:26<00:00,  2.98it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:26<00:00,  2.96it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:27<00:00,  2.92it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:26<00:00,  2.95it/s]\n"
     ]
    }
   ],
   "source": [
    "# Pair every row number of the 'train' split with its split name, the item shape `loop` unpacks.\n",
    "indices = [(row_no, 'train') for row_no in range(len(ds['train']))]\n",
    "# At most 20 workers, and never more workers than rows.\n",
    "prepared = multiprocessing(indices, loop, cores=min(len(indices), 20))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "d78101b4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "10878"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Total number of instruction rows gathered from all workers.\n",
    "len(prepared)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "8c024481",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Build a DataFrame from the instruction dicts and write it to a single parquet file.\n",
    "pd.DataFrame(prepared).to_parquet('synthetic-speaker-diarization-dataset.parquet')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "59e4627e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Uploading files using Xet Storage..\n",
      "Uploading...: 100%|█████████████████████████| 1.89M/1.89M [00:04<00:00, 403kB/s]\n",
      "https://huggingface.co/datasets/mesolitica/Speaker-Diarization-Instructions/blob/main//data/synthetic_speaker_diarization_dataset-00000-of-00001.parquet\n"
     ]
    }
   ],
   "source": [
    "# The destination path is relative to the repo root; a leading slash shows up as 'main//data/...' in the resulting URL.\n",
    "!huggingface-cli upload mesolitica/Speaker-Diarization-Instructions \\\n",
    "synthetic-speaker-diarization-dataset.parquet data/synthetic_speaker_diarization_dataset-00000-of-00001.parquet \\\n",
    "--repo-type=dataset"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
